import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import error, spaces, utils
from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
from typing import List, Tuple

class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.

    Args:
        env: environment to wrap; ``env.seed(seed)`` is called before every reset.
        seeds: non-empty iterable of seeds, consumed round-robin, one per reset.
        seed_idx: index into ``seeds`` of the seed used by the next reset.

    Raises:
        ValueError: if ``seeds`` is empty.
    """

    def __init__(self, env, seeds=(0,), seed_idx=0):
        # Copy so that later mutation of the caller's sequence has no effect;
        # the default is an immutable tuple to avoid a shared mutable default.
        self.seeds = list(seeds)
        if not self.seeds:
            raise ValueError('ReseedWrapper requires at least one seed')
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        seed = self.seeds[self.seed_idx]
        # Advance cyclically so the seed list repeats indefinitely.
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus keyed on (position, direction, action).

    After each step the reward is increased by ``1 / sqrt(n)``. Here ``n`` counts
    how many times (this step included) ``action`` has been recorded together
    with the agent's position and direction as read *after* the step. Counts
    live in ``self.counts`` and are not cleared on reset.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        base_env = self.unwrapped
        visit_key = (tuple(base_env.agent_pos), base_env.agent_dir, action)

        visits = self.counts.get(visit_key, 0) + 1
        self.counts[visit_key] = visits

        # Bonus decays with the inverse square root of the visit count.
        reward += 1 / math.sqrt(visits)
        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which grid positions are visited.

    After each step the reward is increased by ``1 / sqrt(n)``. Here ``n`` counts
    how many times (this step included) the agent's post-step position has been
    seen. Counts live in ``self.counts`` and persist across resets.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Key on the position *after* the environment update.
        position_key = tuple(self.unwrapped.agent_pos)

        visits = self.counts.get(position_key, 0) + 1
        self.counts[position_key] = visits

        reward += 1 / math.sqrt(visits)
        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.

    The observation space becomes the wrapped env's ``'image'`` sub-space, and
    each observation dict is reduced to its ``'image'`` entry.
    """

    def __init__(self, env):
        super().__init__(env)
        image_space = env.observation_space.spaces['image']
        self.observation_space = image_space

    def observation(self, obs):
        image = obs['image']
        return image

class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.

    Each cell's (object, color, state) triple becomes a vector of
    ``len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)`` bits with
    one bit set in each of the three consecutive segments.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        # Kept for interface compatibility; the encoding does not use it.
        self.tile_size = tile_size

        obs_shape = env.observation_space['image'].shape

        # Number of bits per cell
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], num_bits),
            dtype='uint8'
        )
        # Build a new Dict instead of assigning into the existing one: the
        # wrapper's observation_space is the wrapped env's object, so in-place
        # edits would change the inner env's advertised space too.
        self.observation_space = spaces.Dict({
            **self.observation_space.spaces,
            'image': new_image_space,
        })

    def observation(self, obs):
        img = obs['image']
        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')

        num_objects = len(OBJECT_TO_IDX)
        num_colors = len(COLOR_TO_IDX)

        # Cast to a wide integer type so adding segment offsets cannot wrap
        # around in the image's (possibly uint8) dtype.
        obj_idx = img[:, :, 0].astype(np.intp)
        color_idx = img[:, :, 1].astype(np.intp)
        state_idx = img[:, :, 2].astype(np.intp)
        rows, cols = np.indices(img.shape[:2])

        # Set one bit per segment for every cell at once.
        out[rows, cols, obj_idx] = 1
        out[rows, cols, num_objects + color_idx] = 1
        out[rows, cols, num_objects + num_colors + state_idx] = 1

        return {
            'mission': obs['mission'],
            'image': out
        }

class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use fully observable RGB image as the only observation output,
    no language/mission. This can be used to have the agent to solve the
    gridworld in pixel space.

    The ``'image'`` entry is replaced by ``env.render(mode='rgb_array',
    highlight=False, tile_size=tile_size)``; ``'mission'`` is passed through.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        # NOTE(review): declared as (width * tile, height * tile, 3); confirm
        # this axis order matches what env.render returns on non-square grids.
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype='uint8'
        )
        # New Dict rather than an in-place edit: the existing Dict is shared
        # with the wrapped env, which must keep its own image space.
        self.observation_space = spaces.Dict({
            **self.observation_space.spaces,
            'image': new_image_space,
        })

    def observation(self, obs):
        env = self.unwrapped

        rgb_img = env.render(
            mode='rgb_array',
            highlight=False,
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img
        }


class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use partially observable RGB image as the only observation output
    This can be used to have the agent to solve the gridworld in pixel space.

    The encoded partial view is rendered with ``env.get_obs_render`` at
    ``tile_size`` pixels per cell; ``'mission'`` is passed through.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces['image'].shape
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype='uint8'
        )
        # New Dict rather than an in-place edit, so the wrapped env's
        # observation space (shared with this wrapper) is left untouched.
        self.observation_space = spaces.Dict({
            **self.observation_space.spaces,
            'image': new_image_space,
        })

    def observation(self, obs):
        env = self.unwrapped

        rgb_img_partial = env.get_obs_render(
            obs['image'],
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img_partial
        }

class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.

    The ``'image'`` entry is ``env.grid.encode()`` with the agent's cell
    overwritten by (agent, red, agent_dir); ``'mission'`` is passed through.
    """

    def __init__(self, env):
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )
        # New Dict rather than an in-place edit, so the wrapped env's
        # observation space (shared with this wrapper) keeps its own shape.
        self.observation_space = spaces.Dict({
            **self.observation_space.spaces,
            'image': new_image_space,
        })

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        agent_x, agent_y = env.agent_pos[0], env.agent_pos[1]
        # The grid encoding has no agent, so stamp it in; the direction goes
        # in the third (state) channel.
        full_grid[agent_x, agent_y] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return {
            'mission': obs['mission'],
            'image': full_grid
        }

class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.

    The mission is lower-cased and each character is one-hot encoded over
    ``numCharCodes`` (27) codes: 'a'..'z' map to 0..25 and a space maps to 26.
    The mission block has ``maxStrLen`` rows.

    Raises (from ``observation``):
        ValueError: if the mission is longer than ``maxStrLen`` characters or
            contains a character other than a letter or a space.
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        # 26 letters + space
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype='uint8'
        )

        # Raw mission string of the last encoding, and its one-hot array.
        self.cachedStr = None
        self.cachedArray = None

    def _encode_mission(self, mission):
        """Return the (maxStrLen, numCharCodes) float32 one-hot encoding of a lower-cased mission."""
        strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
        space_code = ord('z') - ord('a') + 1

        for idx, ch in enumerate(mission):
            if 'a' <= ch <= 'z':
                chNo = ord(ch) - ord('a')
            elif ch == ' ':
                chNo = space_code
            else:
                # There is no code for this character; fail loudly rather
                # than leave the row empty or reuse another character's code.
                raise ValueError(
                    'unsupported character {!r} in mission string'.format(ch)
                )
            strArray[idx, chNo] = 1

        return strArray

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Re-encode only when the mission changes. The cache key is the raw
        # mission as received, so an unchanged mission is a cache hit.
        if mission != self.cachedStr:
            lowered = mission.lower()
            if len(lowered) > self.maxStrLen:
                raise ValueError('mission string too long ({} chars)'.format(len(lowered)))
            self.cachedArray = self._encode_mission(lowered)
            self.cachedStr = mission

        # NOTE(review): concatenating the uint8 image with the float32 mission
        # array yields float32, although observation_space declares uint8.
        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs

class ViewSizeWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.

    Sets ``env.unwrapped.agent_view_size`` and advertises an
    (agent_view_size, agent_view_size, 3) uint8 ``'image'`` space.

    Raises:
        ValueError: if ``agent_view_size`` is even or smaller than 3.
    """

    def __init__(self, env, agent_view_size=7):
        # Explicit checks instead of ``assert``, which is stripped under -O.
        if agent_view_size % 2 != 1:
            raise ValueError('agent_view_size must be odd, got {}'.format(agent_view_size))
        if agent_view_size < 3:
            raise ValueError('agent_view_size must be >= 3, got {}'.format(agent_view_size))

        super().__init__(env)

        # Override default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        image_space = spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space, keeping any other
        # entries the wrapped Dict space declares (only 'image' changes).
        base_space = self.observation_space
        other_spaces = dict(base_space.spaces) if isinstance(base_space, spaces.Dict) else {}
        self.observation_space = spaces.Dict({
            **other_spaces,
            'image': image_space,
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)

from .minigrid import Goal, Key, Ball
class DirectionObsWrapper(gym.core.ObservationWrapper):
    """
    Provides the slope/angular direction to the goal with the observations,
    modeled as (y2 - y1) / (x2 - x1), where (x1, y1) is the agent position
    and (x2, y2) the goal position.
    type = {slope , angle}; 'angle' stores arctan(slope) instead of the slope.

    The value is written to ``obs['goal_direction']`` on reset and on every step.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal, same axis order as agent_pos; None if no Goal.
        self.goal_position = None
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x),
        # i.e. the (x, y) convention used when indexing grid.encode()[x, y].
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-locate every episode, since the layout may be regenerated on reset.
        # With several goals, the first one found is used.
        self.goal_position = self._find_first(Goal)
        return self.observation(obs)

    def observation(self, obs):
        if self.goal_position is None:
            raise ValueError('DirectionObsWrapper: no Goal object found in the grid')
        goal_x, goal_y = self.goal_position
        # np.divide yields inf/nan (with a RuntimeWarning) when dx == 0.
        slope = np.divide(goal_y - self.agent_pos[1], goal_x - self.agent_pos[0])
        obs['goal_direction'] = np.arctan(slope) if self.type == 'angle' else slope
        return obs

def is_traversable(tile):
    """
    Return True unless the encoded tile holds a blocking object.

    ``tile`` is one cell of an encoded grid; only its first entry (the object
    index) is inspected. Object indices 2 and 9 (wall and lava, per the
    original author) are treated as impassable; everything else is passable.
    """
    wall_idx, lava_idx = 2, 9
    return tile[0] not in {wall_idx, lava_idx}

def is_traversable_empty(tile):
    """
    Return True if the encoded tile is free space the agent can walk through.

    ``tile`` is one cell of an encoded grid: (object index, color, state).
    Only empty cells (object index 1) and open doors (object index 4 with
    state 0) qualify. Every other object, including closed/locked doors,
    walls, lava, keys, balls, boxes and goals, counts as blocking.
    """
    object_idx = tile[0]
    if object_idx == 1:  # empty cell
        return True
    # A door (4) in state 0 is open and can be walked through.
    return bool(object_idx == 4 and tile[2] == 0)

from collections import deque
def bfs(grid, start, end):
    """
    Breadth-first search for a shortest 4-connected path on an encoded grid.

    Cells are entered only if they lie inside ``grid`` and ``is_traversable``
    accepts them; ``start`` itself is not checked.

    Args:
    - grid (np.array): encoded grid of shape (n, m, 3), indexed grid[x, y].
    - start (tuple): starting position (x, y); must be hashable.
    - end (tuple): target position (x, y).

    Returns:
    - path (list): tuples from start to end inclusive, or an empty list if
      end is never reached.
    """
    size_x, size_y = grid.shape[0], grid.shape[1]
    # Each discovered cell remembers the cell it was first reached from.
    came_from = {start: None}
    frontier = deque([start])

    while frontier:
        node = frontier.popleft()
        if node == end:
            # Walk the parent links back to start, then flip to start->end order.
            route = []
            while node is not None:
                route.append(node)
                node = came_from[node]
            route.reverse()
            return route

        for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)):
            nxt = (node[0] + dx, node[1] + dy)
            if not (0 <= nxt[0] < size_x and 0 <= nxt[1] < size_y):
                continue
            if nxt in came_from:
                continue
            if not is_traversable(grid[nxt]):
                continue
            came_from[nxt] = node
            frontier.append(nxt)

    return []

def bfs_empty(grid, start, end):
    """
    Breadth-first search restricted to free cells (empty cells and open doors).

    Intermediate cells must be inside ``grid``, not yet visited, and accepted
    by ``is_traversable_empty``. The ``end`` cell is accepted even if it is not
    free (e.g. a closed door or a ball), so a path can terminate on an object.

    Args:
    - grid (np.array): encoded grid of shape (n, m, 3), indexed grid[x, y].
    - start (tuple): starting position (x, y).
    - end (tuple): target position (x, y). A value that never equals a cell
      (e.g. None) makes the search exhaust the reachable area and return [].

    Returns:
    - path (list): tuples from start to end inclusive along a shortest path,
      or an empty list if no path is found.
    """
    queue = deque([start])
    paths = {start: [start]}
    directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
    while queue:
        current = queue.popleft()
        if current == end:
            return paths[current]
        for dx, dy in directions:
            neighbor = (current[0] + dx, current[1] + dy)
            if not (0 <= neighbor[0] < grid.shape[0] and 0 <= neighbor[1] < grid.shape[1]):
                continue
            # The visited check must also apply to `end`: re-adding it would
            # overwrite its shortest path with a longer one.
            if neighbor in paths:
                continue
            if neighbor == end or is_traversable_empty(grid[neighbor]):
                paths[neighbor] = paths[current] + [neighbor]
                queue.append(neighbor)
    return []  # Return an empty list if no path is found

def get_position(grid, object_type, color=None):
    """
    Get the position of the first object of the given type in the grid.

    Cells are scanned in row-major order over (i, j) = grid[i, j].

    Args:
    - grid (np.array): encoded grid of shape (n, m, 3).
    - object_type (int): object index to match against channel 0.
    - color (int, optional): if given, channel 1 must also equal this value.

    Returns:
    - position (tuple): (i, j) as Python ints, or None if nothing matches.
    """
    matches = grid[:, :, 0] == object_type
    if color is not None:
        matches &= grid[:, :, 1] == color

    hits = np.argwhere(matches)  # row-major order, same as a nested i/j scan
    if len(hits) == 0:
        return None
    first_i, first_j = hits[0]
    return (int(first_i), int(first_j))

def get_position_on_path(grid, agent_pos, final_pos, object_type, color=None, closed=None):
    """
    Return the first cell on the BFS path from agent_pos to final_pos holding a matching object.

    The path comes from ``bfs(grid, agent_pos, final_pos)`` and is scanned from
    the agent outwards. A cell matches when channel 0 equals ``object_type``,
    channel 1 equals ``color`` (if given), and channel 2 equals ``closed``
    (if given; the state channel).

    Returns the matching (x, y) tuple, or None if no cell on the path matches
    or no path exists.
    """
    def _matches(cell):
        if cell[0] != object_type:
            return False
        if color is not None and cell[1] != color:
            return False
        if closed is not None and cell[2] != closed:
            return False
        return True

    route = bfs(grid, agent_pos, final_pos)
    return next((pos for pos in route if _matches(grid[pos[0], pos[1]])), None)

class MultiRoomBfsBinningObsWrapper(gym.core.ObservationWrapper):
    """
    Provides the distance to the goal.

    On every step, ``obs['goal_distance']`` is set to the BFS path length from
    the agent to the goal (both end cells included) plus the number of doors
    currently in the closed state.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal; refreshed on reset and on every observation.
        self.goal_position = None
        # Not used by this wrapper; kept for signature parity.
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x),
        # matching grid.encode()[x, y] indexing.
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-located every episode because the layout may be regenerated.
        self.goal_position = self._find_first(Goal)
        # NOTE(review): the reset observation is not passed through
        # observation(), so it carries no 'goal_distance'.
        return obs

    def observation(self, obs):
        self.goal_position = self._find_first(Goal)
        if self.goal_position is None:
            raise RuntimeError("Goal position error: no Goal object in the grid")

        encoded_grid = self.grid.encode()
        goal_x, goal_y = self.goal_position
        # Sanity check that flat-index decoding agrees with the encoding.
        if encoded_grid[goal_x, goal_y, 0] != OBJECT_TO_IDX['goal']:
            raise RuntimeError("Goal position error")

        path = bfs(encoded_grid, (self.agent_pos[0], self.agent_pos[1]), self.goal_position)
        if not path:
            raise RuntimeError(
                "bfs error: goal cell {} unreachable (encoding {})".format(
                    self.goal_position, encoded_grid[goal_x, goal_y]))

        closed_doors = ((encoded_grid[:, :, 0] == OBJECT_TO_IDX['door'])
                        & (encoded_grid[:, :, 2] == STATE_TO_IDX['closed'])).sum()
        obs['goal_distance'] = len(path) + closed_doors
        return obs

class MultiRoomBinningObsWrapper(gym.core.ObservationWrapper):
    """
    Provides a coarse progress measure toward the goal.

    On every step, ``obs['goal_distance']`` is set to the number of doors
    currently in the closed state.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal recorded at reset, or None; not used by observation().
        self.goal_position = None
        # Not used by this wrapper; kept for signature parity.
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x).
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-located every episode because the layout may be regenerated.
        self.goal_position = self._find_first(Goal)
        # NOTE(review): the reset observation is not passed through observation().
        return obs

    def observation(self, obs):
        encoded_grid = self.grid.encode()
        closed_doors = ((encoded_grid[:, :, 0] == OBJECT_TO_IDX['door'])
                        & (encoded_grid[:, :, 2] == STATE_TO_IDX['closed']))
        obs['goal_distance'] = closed_doors.sum()
        return obs

class KeyCorridorBinningObsWrapper(gym.core.ObservationWrapper):
    """
    Provides a hand-designed progress bin for key-and-ball tasks.

    ``obs['goal_distance']`` is computed from BFS path lengths (capped at 12)
    to the key and to the ball, plus whether a locked door remains:
      - 22..32 (key distance + 20): a locked door exists, the key is at
        distance >= 2 and the ball at distance > 2;
      - 1: key at distance >= 2 and the agent right next to the ball;
      - otherwise ball distance + 1 if a locked door remains, else ball distance.
    A key or ball that is absent from the grid (e.g. carried) or unreachable
    has distance 0.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal recorded at reset, or None; not used by observation().
        self.goal_position = None
        # Not used by this wrapper; kept for signature parity.
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x).
        # Decoding with width for both axes keeps non-square grids correct.
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-located every episode because the layout may be regenerated.
        self.goal_position = self._find_first(Goal)
        # NOTE(review): the reset observation is not passed through observation().
        return obs

    def observation(self, obs):
        encoded_grid = self.grid.encode()
        agent = (self.agent_pos[0], self.agent_pos[1])

        key_position = self._find_first(Key)
        ball_position = self._find_first(Ball)

        # Path lengths include the start cell; absent objects give 0.
        key_dist = len(bfs(encoded_grid, agent, key_position)) if key_position is not None else 0
        ball_dist = len(bfs(encoded_grid, agent, ball_position)) if ball_position is not None else 0

        # Put ball_dist and key_dist at respective maxes
        ball_dist = min(ball_dist, 12)  # Was 8 for size 3/4
        key_dist = min(key_dist, 12)  # Was 6 for size 3/4

        locked_door = bool(((encoded_grid[:, :, 0] == OBJECT_TO_IDX['door'])
                            & (encoded_grid[:, :, 2] == STATE_TO_IDX['locked'])).any())

        if key_dist >= 2 and ball_dist > 2 and locked_door:
            # First stage: the door is still locked and the key is still on the grid.
            dist = key_dist + 20
        else:
            dist = ball_dist + int(locked_door)
        if key_dist >= 2 and ball_dist == 2:
            # Key put down elsewhere after opening the door, and the agent is
            # right next to the ball: last step.
            dist = 1
        obs['goal_distance'] = dist
        return obs
    
# Try making one for obstructed door
class ObstructedMazeBinningObsWrapper(gym.core.ObservationWrapper):
    """
    Provides a progress bin toward the blue ball.

    ``obs['goal_distance']`` is the free-space (``bfs_empty``) path length to
    the blue ball, capped at 8, when that path exists. Otherwise it is
    100 + the free-space path length to the locked door on the ball's BFS
    path, capped at 5; that length is 0 when there is no such door or no
    free path to it.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal recorded at reset, or None; not used by observation().
        self.goal_position = None
        # Not used by this wrapper; kept for signature parity.
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x).
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-located every episode because the layout may be regenerated.
        self.goal_position = self._find_first(Goal)
        # NOTE(review): the reset observation is not passed through observation().
        return obs

    def observation(self, obs):
        grid = self.grid.encode()
        agent_pos = (self.agent_pos[0], self.agent_pos[1])

        # 1. Position of the blue ball (None if absent).
        ball_pos = get_position(grid, OBJECT_TO_IDX['ball'], COLOR_TO_IDX['blue'])
        # 2. First locked door on the BFS path from the agent to the ball.
        door_pos = get_position_on_path(
            grid, agent_pos, ball_pos, OBJECT_TO_IDX['door'], closed=STATE_TO_IDX['locked'])

        # 3. Free-space path lengths; a missing target yields 0 without searching.
        path_door = len(bfs_empty(grid, agent_pos, door_pos)) if door_pos is not None else 0
        path_door = min(path_door, 5)
        path_ball = len(bfs_empty(grid, agent_pos, ball_pos)) if ball_pos is not None else 0
        path_ball = min(path_ball, 8)

        # Bins: free path to the ball if it exists, otherwise 100 + door distance.
        obs['goal_distance'] = path_ball if path_ball > 0 else path_door + 100
        return obs

# Let's copy in the ones from GPT
# ObstructedMazeGPT
class BinningObsWrapper(gym.core.ObservationWrapper):
    """
    Provides a progress bin toward the blue ball (obstructed-maze tasks).

    ``observation(obs, max_progress)`` is meant to be called explicitly; it
    returns ``(obs, max_progress)`` rather than just ``obs``.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal recorded at reset, or None; not used by observation().
        self.goal_position = None
        # Not used by this wrapper; kept for signature parity.
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x).
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-located every episode because the layout may be regenerated.
        self.goal_position = self._find_first(Goal)
        return obs

    def progress_function(self) -> Tuple[List[float], List[bool]]:
        """
        Return (progress_vars, progress_directions).

        With a locked door on the BFS path to the blue ball, the variables are
        [path length to that door, path length to the ball]; otherwise just
        [path length to the ball]. Direction False means "should decrease".
        A missing ball gives inf; an unreachable target gives 0 (empty path).
        """
        grid = self.grid.encode()  # encode once; reused for every search
        agent = (self.agent_pos[0], self.agent_pos[1])

        ball_pos = get_position(grid, OBJECT_TO_IDX['ball'], COLOR_TO_IDX['blue'])
        locked_door_pos = get_position_on_path(
            grid, agent, ball_pos, OBJECT_TO_IDX['door'], closed=STATE_TO_IDX['locked'])

        ball_path_length = len(bfs(grid, agent, ball_pos)) if ball_pos else float('inf')

        if locked_door_pos:
            # Reach the locked door first, then the ball behind it.
            door_path_length = len(bfs(grid, agent, locked_door_pos))
            return [door_path_length, ball_path_length], [False, False]

        # No locked door on the way: head straight for the ball.
        return [ball_path_length], [False]

    def observation(self, obs, max_progress):
        """
        Write a progress bin into ``obs['goal_distance']``.

        On the first call (``max_progress is None``) the current progress
        values become the clipping ceilings; later calls clip by them.
        Returns ``(obs, max_progress)``.
        """
        progress_vars, _ = self.progress_function()

        # Replace infs and nans in progress_vars
        progress_vars = [0 if math.isnan(var) or math.isinf(var) else var for var in progress_vars]

        if max_progress is None:
            max_progress = list(progress_vars)
        else:
            # The number of variables can differ between calls (1 vs 2);
            # variables without a recorded ceiling are left unclipped.
            # NOTE(review): ceilings are matched by position, so after the door
            # opens the ball length is clipped by the old door ceiling.
            progress_vars = [
                min(var, max_progress[i]) if i < len(max_progress) else var
                for i, var in enumerate(progress_vars)
            ]

        if len(progress_vars) == 1:
            obs['goal_distance'] = progress_vars[0]
        else:
            obs['goal_distance'] = progress_vars[0]*100 + progress_vars[1]*(progress_vars[0] == 0)

        return obs, max_progress

# MultiRoomGPT
class MultiRoomGPTBinningObsWrapper(gym.core.ObservationWrapper):
    """
    Provides the BFS distance to the goal as a progress bin.

    ``observation(obs, max_progress)`` is meant to be called explicitly; it
    returns ``(obs, max_progress)`` rather than just ``obs``.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal recorded at reset, or None; not used by observation().
        self.goal_position = None
        # Not used by this wrapper; kept for signature parity.
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x).
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-located every episode because the layout may be regenerated.
        self.goal_position = self._find_first(Goal)
        return obs

    def progress_function(self):
        """
        Return ([path length to goal], [False]).

        The length is that of the BFS path from the agent to the first goal
        cell (inf if no path exists); False means it should decrease.

        Raises:
            ValueError: if the grid contains no goal.
        """
        grid = self.grid.encode()

        goal_pos = get_position(grid, OBJECT_TO_IDX['goal'])
        if goal_pos is None:
            raise ValueError("No goal found in the grid.")

        path = bfs(grid, (self.agent_pos[0], self.agent_pos[1]), goal_pos)
        path_length = len(path) if path else float('inf')

        return [path_length], [False]

    def observation(self, obs, max_progress):
        """
        Write the (clipped) goal path length into ``obs['goal_distance']``.

        On the first call (``max_progress is None``) the current value becomes
        the ceiling; later calls clip by it. Returns ``(obs, max_progress)``.
        """
        progress_vars, _ = self.progress_function()

        # An unreachable goal (inf) or nan is mapped to 0.
        progress_vars = [0 if math.isnan(var) or math.isinf(var) else var for var in progress_vars]

        if max_progress is None:
            max_progress = list(progress_vars)
        else:
            progress_vars = [min(var, max_progress[i]) for i, var in enumerate(progress_vars)]

        obs['goal_distance'] = progress_vars[0]

        return obs, max_progress

# Now implement key + ball with GPT
class KeyCorridorGPTBinningObsWrapper(gym.core.ObservationWrapper):
    """
    Provides a progress bin from BFS distances to the key and to the ball.

    ``observation(obs, max_progress)`` is meant to be called explicitly; it
    returns ``(obs, max_progress)`` rather than just ``obs``.
    """
    def __init__(self, env, type='slope'):
        super().__init__(env)
        # (x, y) of the goal recorded at reset, or None; not used by observation().
        self.goal_position = None
        # Not used by this wrapper; kept for signature parity.
        self.type = type

    def _find_first(self, obj_type):
        """Return the (x, y) cell of the first grid object of ``obj_type``, or None."""
        # Assumes Grid.grid is a row-major flat list (index = y * width + x).
        for idx, cell in enumerate(self.grid.grid):
            if isinstance(cell, obj_type):
                y, x = divmod(idx, self.width)
                return (x, y)
        return None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Re-located every episode because the layout may be regenerated.
        self.goal_position = self._find_first(Goal)
        return obs

    @staticmethod
    def _path_length(grid, start, target):
        """BFS path length to ``target``: 0 if target is None, inf if unreachable."""
        if target is None:
            # Object not on the grid (presumably picked up): nothing left to travel.
            return 0
        path = bfs(grid, start, target)
        return len(path) if path else float('inf')

    def progress_function(self) -> Tuple[List[float], List[bool]]:
        """
        Return ([key distance, ball distance], [False, False]).

        Distances are BFS path lengths from the agent (0 if the object is not
        on the grid, inf if unreachable); False means they should decrease.
        """
        grid = self.grid.encode()
        agent = (self.agent_pos[0], self.agent_pos[1])

        key_pos = get_position(grid, OBJECT_TO_IDX['key'])
        ball_pos = get_position(grid, OBJECT_TO_IDX['ball'])

        progress_vars = [
            self._path_length(grid, agent, key_pos),
            self._path_length(grid, agent, ball_pos),
        ]
        return progress_vars, [False, False]

    def observation(self, obs, max_progress):
        """
        Write a progress bin into ``obs['goal_distance']``.

        The bin is the key distance while the key is still counted; once it is
        0, the bin is 100 * ball distance. On the first call
        (``max_progress is None``) the current values become the clipping
        ceilings. Returns ``(obs, max_progress)``.
        """
        progress_vars, _ = self.progress_function()

        # Unreachable targets (inf) or nan are mapped to 0.
        progress_vars = [0 if math.isnan(var) or math.isinf(var) else var for var in progress_vars]

        if max_progress is None:
            max_progress = list(progress_vars)
        else:
            progress_vars = [min(var, max_progress[i]) for i, var in enumerate(progress_vars)]

        obs['goal_distance'] = progress_vars[0] + 100*progress_vars[1]*(progress_vars[0] == 0)

        return obs, max_progress

